############################################
# Experiment
############################################

import os
import sys
import pandas as pd
import numpy as np

# Directory holding this project's code. It is added to the import search
# path (so the local `utils` module can be imported below) and made the
# working directory (relative paths used later resolve against it).
CODE_DIR = '[PROJECT_PATH]/Code'
sys.path.append(CODE_DIR)
os.chdir(CODE_DIR)

from utils import classify_tweets

import torch
# Report whether CUDA is usable and, if it is, the name of device 0.
print("GPU Available: ", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU Type: ", torch.cuda.get_device_name(0))

#----------------------
# Set Hyperparameters
#----------------------

# Maps each classifier name to the label identifiers it predicts. The
# identifiers are used both as column names of the coded-tweet table and as
# `codebook_id` values when looking up label descriptions. The script trains
# one model per entry, in insertion order, and the key becomes part of the
# model's output directory.
# NOTE(review): "Sub-StateOfEconomy" uses a hyphen where every other key uses
# an underscore. It is kept as-is because the key ends up in directory
# names -- confirm whether this is intended.
MODEL_NAME_VARIABLE_MAP = {
    "Sub_Serious": ["1", "2"],
    "Sub_PolicyHealthcare": ["9", "10"],
    "Sub_Inequality": ["7"],
    "Sub_EvaluateTrump": ["31", "32", "33"],
    "Sub_EvaluateFed": ["28", "29", "30"],
    "Sub-StateOfEconomy": ["6"],
    "Sub_EvaluateGovernor": ["34", "35", "36"],
    "Sub_PolicyMasks": ["11", "12"],
    "Sub_EconomicRelief": ["21", "22"],
}

# Passed to the classifier as its model type and used in model/directory names.
MODEL_TYPE = "Multi-class"

# Both run settings come from the command line:
#     <script> <subset> <random_seed>
# e.g. subset "txt" and random seed 2.
# Fail early with a readable message instead of a bare IndexError/ValueError.
if len(sys.argv) < 3:
    raise SystemExit(f"Usage: {sys.argv[0]} <subset> <random_seed>")

# Which treatment group the training data is restricted to (see prepare_data).
SUBSET = sys.argv[1]
print(f"Subset: {SUBSET}")

# Seed used for shuffling the data and passed on to the classifier.
RANDOM_SEED = sys.argv[2]
print(f"Random seed: {RANDOM_SEED}")
try:
    RANDOM_SEED = int(RANDOM_SEED)
except ValueError:
    raise SystemExit(
        f"Random seed must be an integer, got {RANDOM_SEED!r}"
    ) from None


#----------------------
# Helper for data prep
#----------------------

def _select_text_and_labels(frame, label_select):
    """Return a new frame with `tweet_id_str`, `text` and a `labels` column.

    Each `labels` entry is the row's values of the `label_select` columns as
    one array. The result is built on an explicit copy so that no column is
    ever assigned onto a filtered slice of another DataFrame.
    """
    out = frame[['tweet_id_str', 'text']].copy()
    out['labels'] = list(frame[label_select].values)
    return out


def prepare_data(label_select, subset, random_seed):
    """Load the codebook and coded tweets and build train/validation frames.

    Reads the module-level globals PATH_CODEBOOK (codebook CSV) and PATH_DATA
    (directory containing `coded_label_wide_experiment.csv`); both must be
    defined before this function is called.

    Parameters
    ----------
    label_select : list of str
        Label identifiers. Used as column names of the coded-tweet table and
        as `codebook_id` values for the codebook lookup.
    subset : str
        'txt' keeps training rows with `treatment == 0`; any other value
        keeps training rows with `treatment == 1`.
    random_seed : int
        Seed for shuffling the rows, which determines the 80/20 split.

    Returns
    -------
    (d_tr, d_va, codebook)
        `d_tr` and `d_va` are DataFrames with columns `tweet_id_str`, `text`
        and `labels`; `codebook` is a list of "<category> - <label>" strings,
        one per entry of `label_select`.

    Raises
    ------
    IndexError
        If an entry of `label_select` matches no `codebook_id` row.
    """
    # Load codebook and resolve every selected id to its description
    d_codebook = pd.read_csv(PATH_CODEBOOK)
    d_codebook['cat_label'] = d_codebook['category'] + ' - ' + d_codebook['label']
    codebook = [
        d_codebook.query(f"codebook_id == {lab}")['cat_label'].values[0]
        for lab in label_select
    ]

    # Load coded tweets and shuffle them reproducibly
    d = pd.read_csv(f'{PATH_DATA}/coded_label_wide_experiment.csv')
    d = d.sample(frac=1, random_state=random_seed)

    # Split on unique tweet ids (first 80% of the shuffled ids -> training),
    # so that all rows of one tweet land on the same side of the split
    tweet_id_str = d.tweet_id_str.unique()
    train_id, val_id = np.split(tweet_id_str, [int(0.8 * len(tweet_id_str))])

    # Training rows come from the treatment group chosen by `subset`;
    # validation rows always come from the treatment == 1 group
    train_treatment = 0 if subset == 'txt' else 1
    d_tr = d[d.tweet_id_str.isin(train_id) & (d['treatment'] == train_treatment)]
    d_va = d[d.tweet_id_str.isin(val_id) & (d['treatment'] == 1)]

    # Bootstrap the training data (removed)

    # Keep only id, text and the combined label array
    d_tr = _select_text_and_labels(d_tr, label_select)
    d_va = _select_text_and_labels(d_va, label_select)
    return d_tr, d_va, codebook


#----------------------
# Input locations (read by prepare_data as module-level globals)
#----------------------

PATH_DATA = '../Data'
PATH_CODEBOOK = '../Data/coded_codebook.csv'

# Train one classifier per entry of MODEL_NAME_VARIABLE_MAP.
for MODEL_NAME, LABELS in MODEL_NAME_VARIABLE_MAP.items():

    # Every (model, subset, seed) combination gets its own output directory
    model_dir = f"[CLASSIFIERS_PATH]/{MODEL_NAME}/{MODEL_TYPE} {MODEL_NAME} {SUBSET} {RANDOM_SEED}"
    os.makedirs(model_dir, exist_ok=True)

    print(f"Model name: {MODEL_NAME}")
    print("Label: ", LABELS)
    print(f"Model directory: {model_dir}")

    #----------------------
    # Load input data (simple)
    #----------------------

    d_tr, d_va, codebook = prepare_data(LABELS, subset = SUBSET, random_seed = RANDOM_SEED)

    # Show which tweet ids ended up on each side of the split
    train_ids = d_tr.tweet_id_str.unique()
    val_ids = d_va.tweet_id_str.unique()
    print("ID training set: ", train_ids, len(train_ids))
    print("ID validation set: ", val_ids, len(val_ids))
    # Number of validation rows whose tweet id also occurs in the training set
    print("Check for data leakage: ", d_va.tweet_id_str.isin(d_tr.tweet_id_str).sum())

    # The classifier receives the data as a (training, validation) pair
    d = (d_tr, d_va)

    print(d)
    print(codebook)

    #----------------------------
    # Train classifier (simple)
    #----------------------------

    worker = classify_tweets(
        transformer_spec = ['roberta', 'roberta-large'],
        model_name = f"{MODEL_TYPE} {MODEL_NAME} {SUBSET}",
        model_directory = model_dir,
        model_type = MODEL_TYPE,
        n_epochs = 20,
        data = d,
        codebook = codebook,
        random_seed = RANDOM_SEED,
        weighted = True,
    )

    worker.data_setup()
    worker.setup_model(save_model_every_epoch = False, save_no_model = True)
    worker.train_model()
